TensorFlow/Keras

Keras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, or Theano. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.

Note 1: This is not an introduction to deep neural networks, as that would exceed the scope of this notebook. Instead, we want to show you how you can implement a convolutional neural network to classify neuroimages, in our case fMRI images.
Note 2: We want to thank Anisha Keshavan, as a lot of the content in this notebook comes from her introduction notebook about Keras.

Setup

In [1]:
import warnings
warnings.filterwarnings("ignore")

from nilearn import plotting
%matplotlib inline
import numpy as np
import nibabel as nb
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="white")
import os
import datetime
import tensorflow as tf
import plotly.graph_objects as go
from plotly import figure_factory as ff

Load machine learning dataset

We will load a dataset that was prepared to enable quick showcases/introductions for machine learning. It includes an anatomical template image (which we will need for visualization), as well as a 4D fMRI image from a resting-state scan. The dataset is from Zang et al. It contains 48 subjects, each of whom did two resting-state fMRI recordings: once with eyes open and once with eyes closed. The data was already pre-processed and is ready for the machine learning notebooks. Note: The data diverges from the original in that we only consider the first 100 volumes per run for this tutorial; the original dataset had 240 volumes per run.

In [2]:
anat = nb.load('data/MNI152_T1_1mm.nii.gz')
func = nb.load('data/dataset_ML.nii.gz')

Let's check what the 4D fMRI image looks like by plotting its mean over time.

In [3]:
from nilearn.image import mean_img
from nilearn.plotting import plot_anat
In [4]:
plot_anat(mean_img(func), cmap='magma', colorbar=False, display_mode='x', vmax=2, annotate=False,
          cut_coords=range(0, 49, 12), title='Mean value of machine learning dataset');

Specifying labels and chunks

As in every other machine or deep learning application, we need label and chunk variables to train the neural network. The labels tell the network what we want to classify, and the chunks are an easy way to make sure that the training and test sets are split in a balanced way.

So, as before, we again specify which volumes of the dataset were recorded during the eyes-closed resting state and which during the eyes-open resting state.

From the dataset release we know that we have a total of 384 volumes in our dataset_ML.nii.gz file and that it's always 4 volumes of the condition eyes closed, followed by 4 volumes of the condition eyes open, etc. Therefore our labels should be as follows:

In [5]:
labels = np.ravel([[['closed'] * 4, ['open'] * 4] for i in range(48)])
labels[:20]
Out[5]:
array(['closed', 'closed', 'closed', 'closed', 'open', 'open', 'open',
       'open', 'closed', 'closed', 'closed', 'closed', 'open', 'open',
       'open', 'open', 'closed', 'closed', 'closed', 'closed'],
      dtype='<U6')

Second, the chunks variable should not switch within subjects. So, as before, we can again specify 6 chunks of 64 volumes (8 subjects) each:

In [6]:
chunks = np.ravel([[i] * 64 for i in range(6)])
chunks[:150]
Out[6]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

Keras - 2D Example

Convolutional neural networks are very powerful (as you will see), but the computational power required to train them can be incredibly demanding. For this reason, it's sometimes recommended to reduce the input space if possible.

In our case, we could try to train the neural network on only one very thin slab (a few slices) of the brain. So, instead of taking the data matrix of the whole brain, we just take a few slices in the region that we think is most likely to be predictive for the question at hand.

We know (or suspect) that the regions with the most predictive power are probably somewhere around the eyes and in the visual cortex. So let's try to specify a few slices that cover those regions.

So, let's try to just take a few slices around the eyes:

In [7]:
plot_anat(mean_img(func).slicer[...,5:-25], cmap='magma', colorbar=False,
          display_mode='x', vmax=2, annotate=False, cut_coords=range(0, 49, 12),
          title='Slab of the machine learning mean image');

Hmm... That doesn't seem to work. We want to cover the eyes and the visual cortex, but like this we're too far down at the back of the head (in the cerebellum). One solution to this is to rotate the volume.

So let's do that:

In [8]:
# Rotation parameters
phi = 0.35
cos = np.cos(phi)
sin = np.sin(phi)

# Compute rotation matrix around x-axis
rotation_affine = np.array([[1, 0, 0, 0],
                            [0, cos, -sin, 0],
                            [0, sin, cos, 0],
                            [0, 0, 0, 1]])
new_affine = rotation_affine.dot(func.affine)
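
If you want to convince yourself that this is a pure rotation, you can check that the matrix is orthonormal, i.e. that multiplying it by its transpose yields the identity:

# A pure rotation matrix is orthonormal: R @ R.T equals the identity
print(np.allclose(rotation_affine @ rotation_affine.T, np.eye(4)))  # True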
In [9]:
# Rotate and resample image to new orientation
from nilearn.image import resample_img
new_img = nb.Nifti1Image(func.get_fdata(), new_affine)
img_rot = resample_img(new_img, func.affine, interpolation='continuous')
del func
del new_img
In [10]:
# Delete zero-only rows and columns
from nilearn.image import crop_img
img_crop = crop_img(img_rot)
del img_rot

Let's check if the rotation worked.

In [11]:
plot_anat(mean_img(img_crop), cmap='magma', colorbar=False, display_mode='x', vmax=2, annotate=False,
          cut_coords=range(-20, 30, 12), title='Rotated machine learning dataset');

Perfect! And which slab should we take? Let's try the slices 12, 13 and 14.

In [12]:
from nilearn.plotting import plot_stat_map
img_slab = img_crop.slicer[..., 12:15, :]
plot_stat_map(mean_img(img_slab), cmap='magma', bg_img=mean_img(img_crop), colorbar=False,
              display_mode='x', vmax=2, annotate=False, cut_coords=range(-20, 30, 12),
              title='Slab of rotated machine learning dataset');

Perfect, the slab seems to contain exactly what we want. Now that the data is ready we can continue with the actual machine learning part.

Split data into a training and testing set

First things first, we need to define a training and testing set. This is really important because we need to make sure that our model can generalize to new, unseen data. Here, we randomly shuffle our data, and reserve 80% of it for our training data, and the remaining 20% for testing.

So let's first get the data into the right structure for Keras. For this, we need to swap some of the dimensions of our data matrix.

In [13]:
data = np.rollaxis(img_slab.get_fdata(), 3, 0)
data.shape
Out[13]:
(384, 40, 56, 3)

As you can see, the first dimension now indexes the different volumes, while the remaining dimensions contain the volume itself. Keep in mind that the last dimension (here of size 3) is treated as the channels in the Keras model that we will be using below.

Note: To make this notebook reproducible, i.e. always leading to the "same" results, let's set a seed for the random split of the dataset. This should only be done for teaching purposes, not for real research, as randomness and chance are a crucial part of machine learning.

In [14]:
from numpy.random import seed
seed(0)
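
Keep in mind that this only seeds NumPy. If you want the TensorFlow side of things (e.g. the weight initialization) to be more deterministic as well, you could additionally seed TensorFlow itself:

# Also seed TensorFlow for (more) deterministic weight initialization
tf.random.set_seed(0)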

As a next step, let's create an index list that we can use to split the data and labels into training and test sets:

In [15]:
# Create list of indices and shuffle them
N = data.shape[0]
indices = np.arange(N)
np.random.shuffle(indices)

#  Cut the dataset at 80% to create the training and test set
N_80p = int(0.8 * N)
indices_train = indices[:N_80p]
indices_test = indices[N_80p:]

# Split the data into training and test sets
X_train = data[indices_train, ...]
X_test = data[indices_test, ...]

print(X_train.shape, X_test.shape)
(307, 40, 56, 3) (77, 40, 56, 3)

Create outcome variable

We need to define a variable that holds the outcome (1 or 0), indicating whether the resting-state images were recorded with eyes open or closed. Luckily, we already have this information stored in the labels variable above. So let's split these labels into training and test sets:

In [16]:
y_train = labels[indices_train] == 'open'
y_test = labels[indices_test] == 'open'

Data Scaling

Before feeding the data to the network, we scale it. We first explore the training set with PCA and t-SNE, and then z-score each voxel using the mean and standard deviation of the training set only.

In [17]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
In [18]:
scaler = StandardScaler()
pca = PCA()
tsne = TSNE()
In [19]:
# Flatten each volume into a vector and standardize every feature
X_scaled = scaler.fit_transform(X_train.reshape(len(X_train), -1))
In [20]:
X_pca = pca.fit_transform(X_scaled)
In [21]:
# Cumulative explained variance of the principal components
plt.plot(pca.explained_variance_ratio_.cumsum())
Out[21]:
[<matplotlib.lines.Line2D at 0x181fd2440>]
In [22]:
y_train
Out[22]:
array([ True,  True,  True,  True, False, False,  True, False,  True,
        True, False, False, False, False,  True,  True,  True, False,
       False, False,  True, False, False, False, False,  True,  True,
        True, False,  True,  True, False,  True,  True, False, False,
        True,  True,  True,  True, False, False, False, False, False,
        True, False,  True,  True,  True,  True,  True,  True, False,
        True,  True, False,  True, False, False, False,  True,  True,
       False, False, False,  True, False, False, False,  True,  True,
        True,  True, False,  True,  True, False, False, False, False,
       False,  True,  True,  True,  True, False,  True, False,  True,
        True, False, False,  True,  True,  True, False, False,  True,
        True,  True, False,  True,  True,  True,  True,  True, False,
       False,  True, False,  True,  True, False,  True,  True,  True,
       False, False,  True,  True, False,  True, False, False, False,
        True, False,  True,  True,  True, False, False,  True, False,
       False,  True,  True,  True, False,  True,  True, False,  True,
        True,  True, False, False, False,  True, False,  True, False,
        True, False,  True,  True,  True,  True,  True, False,  True,
        True, False, False,  True,  True, False,  True,  True, False,
       False, False, False,  True,  True,  True, False, False, False,
       False,  True,  True,  True, False, False,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
       False, False,  True, False, False,  True,  True, False, False,
        True, False, False, False, False,  True,  True,  True,  True,
        True, False, False,  True,  True,  True, False, False, False,
       False,  True, False,  True,  True, False, False,  True, False,
        True,  True,  True, False, False, False, False,  True,  True,
       False,  True, False, False,  True, False, False,  True, False,
       False,  True, False,  True,  True, False, False, False, False,
        True,  True, False, False, False,  True,  True, False,  True,
        True, False, False,  True, False, False,  True,  True,  True,
       False,  True,  True,  True, False, False,  True, False, False,
       False,  True, False, False, False,  True, False,  True, False,
       False,  True, False,  True,  True, False,  True,  True,  True,
       False])
In [23]:
# First two principal components, colored by label
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y_train, cmap='bwr')
Out[23]:
<matplotlib.collections.PathCollection at 0x1820488e0>
In [24]:
X_tsne = tsne.fit_transform(X_pca)
In [25]:
# t-SNE embedding of the PCA-reduced data, colored by label
plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_train, cmap='bwr')
Out[25]:
<matplotlib.collections.PathCollection at 0x1820c04c0>
In [26]:
# Voxelwise mean across the training volumes
mean = X_train.mean(axis=0)
mean.shape
Out[26]:
(40, 56, 3)
In [27]:
# Voxelwise standard deviation across the training volumes
std = X_train.std(axis=0)
std.shape
Out[27]:
(40, 56, 3)
In [28]:
# Distribution of voxelwise stds; the red line marks the threshold used below
plt.hist(np.ravel(std), bins=100);
plt.vlines(0.05, 0, 1000, colors='red')
Out[28]:
<matplotlib.collections.LineCollection at 0x182048e50>
In [29]:
# Zero out voxels with very low variance (mostly background)
std[std<0.05] = 0
In [30]:
# Distribution of voxelwise means; the red line marks a candidate threshold
plt.hist(np.ravel(mean), bins=100);
plt.vlines(0.25, 0, 1000, colors='red')
Out[30]:
<matplotlib.collections.LineCollection at 0x182203af0>
In [31]:
# Zero out voxels with a very low mean signal
mean[mean<0.05] = 0
In [32]:
# Binary mask of voxels that survived both thresholds (not used further here)
mask = (mean*std)!=0
In [33]:
# Z-score both sets using the training set's voxelwise statistics
X_zscore_tr = (X_train-mean)/std
X_zscore_te = (X_test-mean)/std
X_zscore_tr.shape
Out[33]:
(307, 40, 56, 3)
In [34]:
# Replace NaNs caused by division by zero (voxels with std == 0)
X_zscore_tr[np.isnan(X_zscore_tr)]=0
X_zscore_te[np.isnan(X_zscore_te)]=0
In [35]:
# Replace any remaining infinite values as well
X_zscore_tr[np.isinf(X_zscore_tr)]=0
X_zscore_te[np.isinf(X_zscore_te)]=0

And now we're good to go.

Creating a Sequential Model

Now comes the fun and tricky part. We need to specify the structure of our convolutional neural network. As a quick reminder, a convolutional neural network consists of convolution layers, pooling layers, a flattening layer and some fully connected layers:

Taken from: https://www.mathworks.com/videos/introduction-to-deep-learning-what-are-convolutional-neural-networks--1489512765771.html

So as a first step, let's import all modules that we need to create the Keras model:

In [36]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AvgPool2D, BatchNormalization
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense
from tensorflow.keras.optimizers import Adam, SGD

As a next step, we should specify some of the model parameters that we want to be identical throughout the model:

In [37]:
# Get shape of input data
data_shape = tuple(X_train.shape[1:])

# Specify shape of convolution kernel
kernel_size = (3, 3)

# Specify number of output categories
n_classes = 2

Now comes the big part... the model, i.e. the structure of the neural network! We want to make clear that we're no experts in deep neural networks, and therefore the model below is not necessarily a good one. But we chose it because it can be trained rather quickly and has rather few parameters to estimate.

In [38]:
# Specify number of filters per layer
filters = 32

model = Sequential()

model.add(Conv2D(filters, kernel_size, activation='relu', input_shape=data_shape))
model.add(BatchNormalization())
model.add(MaxPooling2D())
filters *= 2

model.add(Conv2D(filters, kernel_size, activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
filters *= 2

model.add(Conv2D(filters, kernel_size, activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
filters *= 2

model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(1024, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(n_classes, activation='softmax'))
2022-05-25 11:39:45.542956: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
In [39]:
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam', # swap out for sgd 
              metrics=['accuracy'])

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 38, 54, 32)        896       
                                                                 
 batch_normalization (BatchN  (None, 38, 54, 32)       128       
 ormalization)                                                   
                                                                 
 max_pooling2d (MaxPooling2D  (None, 19, 27, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 17, 25, 64)        18496     
                                                                 
 batch_normalization_1 (Batc  (None, 17, 25, 64)       256       
 hNormalization)                                                 
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 8, 12, 64)        0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 6, 10, 128)        73856     
                                                                 
 batch_normalization_2 (Batc  (None, 6, 10, 128)       512       
 hNormalization)                                                 
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 3, 5, 128)        0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 1920)              0         
                                                                 
 dropout (Dropout)           (None, 1920)              0         
                                                                 
 dense (Dense)               (None, 1024)              1967104   
                                                                 
 batch_normalization_3 (Batc  (None, 1024)             4096      
 hNormalization)                                                 
                                                                 
 dropout_1 (Dropout)         (None, 1024)              0         
                                                                 
 dense_1 (Dense)             (None, 256)               262400    
                                                                 
 batch_normalization_4 (Batc  (None, 256)              1024      
 hNormalization)                                                 
                                                                 
 dropout_2 (Dropout)         (None, 256)               0         
                                                                 
 dense_2 (Dense)             (None, 64)                16448     
                                                                 
 batch_normalization_5 (Batc  (None, 64)               256       
 hNormalization)                                                 
                                                                 
 dropout_3 (Dropout)         (None, 64)                0         
                                                                 
 dense_3 (Dense)             (None, 2)                 130       
                                                                 
=================================================================
Total params: 2,345,602
Trainable params: 2,342,466
Non-trainable params: 3,136
_________________________________________________________________

That's what our model looks like! Cool!

Fitting the Model

The next step is now, of course, to fit our model to the training data. In our case we have two parameters that we can work with:

First: How many epochs, i.e. iterations over the full training set, should be computed?

In [40]:
nEpochs = 100  # Increase this value for better results (i.e., more training)

Second: How many elements (volumes) should be considered at once when updating the weights (the batch size)?

In [41]:
batch_size = 32   # Increasing this value might speed up fitting

We will also define a log directory so that we can evaluate our model later on with TensorBoard.

In [42]:
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

So let's fit the model:

In [43]:
%time fit = model.fit(X_zscore_tr, y_train, epochs=nEpochs, batch_size=batch_size,\
                      validation_split=0.2, callbacks=[tensorboard_callback])
Epoch 1/100
8/8 [==============================] - 2s 153ms/step - loss: 1.3926 - accuracy: 0.5020 - val_loss: 0.6548 - val_accuracy: 0.6452
Epoch 2/100
8/8 [==============================] - 1s 124ms/step - loss: 0.9597 - accuracy: 0.6041 - val_loss: 0.6324 - val_accuracy: 0.7258
Epoch 3/100
8/8 [==============================] - 1s 125ms/step - loss: 0.7451 - accuracy: 0.6939 - val_loss: 0.6340 - val_accuracy: 0.6129
Epoch 4/100
8/8 [==============================] - 1s 130ms/step - loss: 0.6463 - accuracy: 0.7102 - val_loss: 0.6153 - val_accuracy: 0.7419
Epoch 5/100
8/8 [==============================] - 1s 119ms/step - loss: 0.6486 - accuracy: 0.7224 - val_loss: 0.5925 - val_accuracy: 0.7742
Epoch 6/100
8/8 [==============================] - 1s 128ms/step - loss: 0.6002 - accuracy: 0.7265 - val_loss: 0.5590 - val_accuracy: 0.8065
Epoch 7/100
8/8 [==============================] - 1s 125ms/step - loss: 0.4804 - accuracy: 0.8041 - val_loss: 0.5460 - val_accuracy: 0.7742
Epoch 8/100
8/8 [==============================] - 1s 150ms/step - loss: 0.4568 - accuracy: 0.8245 - val_loss: 0.5388 - val_accuracy: 0.8065
Epoch 9/100
8/8 [==============================] - 1s 134ms/step - loss: 0.3595 - accuracy: 0.8571 - val_loss: 0.5440 - val_accuracy: 0.8387
Epoch 10/100
8/8 [==============================] - 1s 151ms/step - loss: 0.3120 - accuracy: 0.8816 - val_loss: 0.5462 - val_accuracy: 0.7742
Epoch 11/100
8/8 [==============================] - 1s 145ms/step - loss: 0.3281 - accuracy: 0.8612 - val_loss: 0.5513 - val_accuracy: 0.7097
Epoch 12/100
8/8 [==============================] - 1s 137ms/step - loss: 0.2679 - accuracy: 0.8980 - val_loss: 0.5719 - val_accuracy: 0.6129
Epoch 13/100
8/8 [==============================] - 1s 132ms/step - loss: 0.2698 - accuracy: 0.8939 - val_loss: 0.6865 - val_accuracy: 0.5161
Epoch 14/100
8/8 [==============================] - 1s 129ms/step - loss: 0.1656 - accuracy: 0.9265 - val_loss: 0.8108 - val_accuracy: 0.5161
Epoch 15/100
8/8 [==============================] - 1s 140ms/step - loss: 0.1833 - accuracy: 0.9347 - val_loss: 0.9223 - val_accuracy: 0.5000
Epoch 16/100
8/8 [==============================] - 1s 143ms/step - loss: 0.1648 - accuracy: 0.9429 - val_loss: 0.9505 - val_accuracy: 0.5000
Epoch 17/100
8/8 [==============================] - 1s 146ms/step - loss: 0.1720 - accuracy: 0.9429 - val_loss: 0.8843 - val_accuracy: 0.4839
Epoch 18/100
8/8 [==============================] - 1s 135ms/step - loss: 0.0830 - accuracy: 0.9714 - val_loss: 0.8557 - val_accuracy: 0.4839
Epoch 19/100
8/8 [==============================] - 1s 142ms/step - loss: 0.1872 - accuracy: 0.9306 - val_loss: 0.8204 - val_accuracy: 0.5161
Epoch 20/100
8/8 [==============================] - 1s 136ms/step - loss: 0.1205 - accuracy: 0.9633 - val_loss: 0.8903 - val_accuracy: 0.5000
Epoch 21/100
8/8 [==============================] - 1s 139ms/step - loss: 0.1398 - accuracy: 0.9429 - val_loss: 0.9782 - val_accuracy: 0.5000
Epoch 22/100
8/8 [==============================] - 1s 130ms/step - loss: 0.1030 - accuracy: 0.9592 - val_loss: 1.0515 - val_accuracy: 0.5000
Epoch 23/100
8/8 [==============================] - 1s 148ms/step - loss: 0.0858 - accuracy: 0.9633 - val_loss: 1.0652 - val_accuracy: 0.5161
Epoch 24/100
8/8 [==============================] - 1s 201ms/step - loss: 0.1055 - accuracy: 0.9510 - val_loss: 0.9735 - val_accuracy: 0.5323
Epoch 25/100
8/8 [==============================] - 1s 160ms/step - loss: 0.0699 - accuracy: 0.9714 - val_loss: 0.8374 - val_accuracy: 0.5806
Epoch 26/100
8/8 [==============================] - 1s 151ms/step - loss: 0.0457 - accuracy: 0.9878 - val_loss: 0.7892 - val_accuracy: 0.5968
Epoch 27/100
8/8 [==============================] - 1s 148ms/step - loss: 0.0761 - accuracy: 0.9755 - val_loss: 0.7068 - val_accuracy: 0.6129
Epoch 28/100
8/8 [==============================] - 1s 143ms/step - loss: 0.0493 - accuracy: 0.9837 - val_loss: 0.6480 - val_accuracy: 0.6452
Epoch 29/100
8/8 [==============================] - 1s 138ms/step - loss: 0.1099 - accuracy: 0.9592 - val_loss: 0.6517 - val_accuracy: 0.6129
Epoch 30/100
8/8 [==============================] - 1s 163ms/step - loss: 0.0493 - accuracy: 0.9878 - val_loss: 0.6734 - val_accuracy: 0.6129
Epoch 31/100
8/8 [==============================] - 1s 168ms/step - loss: 0.0295 - accuracy: 0.9959 - val_loss: 0.7009 - val_accuracy: 0.6452
Epoch 32/100
8/8 [==============================] - 1s 171ms/step - loss: 0.0872 - accuracy: 0.9714 - val_loss: 0.6840 - val_accuracy: 0.6452
Epoch 33/100
8/8 [==============================] - 1s 150ms/step - loss: 0.0352 - accuracy: 0.9959 - val_loss: 0.6040 - val_accuracy: 0.6774
Epoch 34/100
8/8 [==============================] - 1s 145ms/step - loss: 0.0497 - accuracy: 0.9796 - val_loss: 0.5444 - val_accuracy: 0.6935
Epoch 35/100
8/8 [==============================] - 1s 152ms/step - loss: 0.0891 - accuracy: 0.9673 - val_loss: 0.4700 - val_accuracy: 0.7742
Epoch 36/100
8/8 [==============================] - 1s 137ms/step - loss: 0.0537 - accuracy: 0.9837 - val_loss: 0.4536 - val_accuracy: 0.7903
Epoch 37/100
8/8 [==============================] - 1s 147ms/step - loss: 0.0260 - accuracy: 0.9959 - val_loss: 0.4615 - val_accuracy: 0.7903
Epoch 38/100
8/8 [==============================] - 1s 144ms/step - loss: 0.0186 - accuracy: 0.9959 - val_loss: 0.4437 - val_accuracy: 0.8226
Epoch 39/100
8/8 [==============================] - 1s 148ms/step - loss: 0.0380 - accuracy: 0.9837 - val_loss: 0.4077 - val_accuracy: 0.8065
Epoch 40/100
8/8 [==============================] - 1s 143ms/step - loss: 0.0128 - accuracy: 1.0000 - val_loss: 0.3810 - val_accuracy: 0.8387
Epoch 41/100
8/8 [==============================] - 1s 135ms/step - loss: 0.0376 - accuracy: 0.9837 - val_loss: 0.3721 - val_accuracy: 0.8548
Epoch 42/100
8/8 [==============================] - 1s 147ms/step - loss: 0.0119 - accuracy: 1.0000 - val_loss: 0.3940 - val_accuracy: 0.8548
Epoch 43/100
8/8 [==============================] - 1s 134ms/step - loss: 0.0182 - accuracy: 1.0000 - val_loss: 0.4053 - val_accuracy: 0.8548
Epoch 44/100
8/8 [==============================] - 1s 136ms/step - loss: 0.0365 - accuracy: 0.9878 - val_loss: 0.3956 - val_accuracy: 0.8387
Epoch 45/100
8/8 [==============================] - 1s 143ms/step - loss: 0.0146 - accuracy: 0.9959 - val_loss: 0.3969 - val_accuracy: 0.8387
Epoch 46/100
8/8 [==============================] - 1s 156ms/step - loss: 0.0413 - accuracy: 0.9837 - val_loss: 0.3770 - val_accuracy: 0.8387
Epoch 47/100
8/8 [==============================] - 1s 156ms/step - loss: 0.0255 - accuracy: 0.9918 - val_loss: 0.3671 - val_accuracy: 0.8548
Epoch 48/100
8/8 [==============================] - 1s 183ms/step - loss: 0.0245 - accuracy: 0.9918 - val_loss: 0.3529 - val_accuracy: 0.8387
Epoch 49/100
8/8 [==============================] - 1s 151ms/step - loss: 0.0537 - accuracy: 0.9796 - val_loss: 0.3366 - val_accuracy: 0.8710
Epoch 50/100
8/8 [==============================] - 1s 140ms/step - loss: 0.0158 - accuracy: 1.0000 - val_loss: 0.3621 - val_accuracy: 0.8387
Epoch 51/100
8/8 [==============================] - 1s 142ms/step - loss: 0.0309 - accuracy: 0.9878 - val_loss: 0.3704 - val_accuracy: 0.8387
Epoch 52/100
8/8 [==============================] - 1s 156ms/step - loss: 0.0151 - accuracy: 0.9918 - val_loss: 0.3585 - val_accuracy: 0.8548
Epoch 53/100
8/8 [==============================] - 1s 154ms/step - loss: 0.0172 - accuracy: 0.9918 - val_loss: 0.3300 - val_accuracy: 0.8871
Epoch 54/100
8/8 [==============================] - 1s 163ms/step - loss: 0.0300 - accuracy: 0.9878 - val_loss: 0.3353 - val_accuracy: 0.8871
Epoch 55/100
8/8 [==============================] - 1s 145ms/step - loss: 0.0144 - accuracy: 0.9918 - val_loss: 0.3379 - val_accuracy: 0.8710
Epoch 56/100
8/8 [==============================] - 1s 161ms/step - loss: 0.0313 - accuracy: 0.9878 - val_loss: 0.3412 - val_accuracy: 0.8710
Epoch 57/100
8/8 [==============================] - 1s 141ms/step - loss: 0.0272 - accuracy: 0.9918 - val_loss: 0.3223 - val_accuracy: 0.8710
Epoch 58/100
8/8 [==============================] - 1s 141ms/step - loss: 0.0123 - accuracy: 0.9959 - val_loss: 0.3440 - val_accuracy: 0.8548
Epoch 59/100
8/8 [==============================] - 1s 163ms/step - loss: 0.0206 - accuracy: 0.9959 - val_loss: 0.3426 - val_accuracy: 0.8387
Epoch 60/100
8/8 [==============================] - 1s 145ms/step - loss: 0.0411 - accuracy: 0.9796 - val_loss: 0.3850 - val_accuracy: 0.8226
Epoch 61/100
8/8 [==============================] - 1s 143ms/step - loss: 0.0720 - accuracy: 0.9673 - val_loss: 0.5091 - val_accuracy: 0.7742
Epoch 62/100
8/8 [==============================] - 1s 133ms/step - loss: 0.0228 - accuracy: 0.9878 - val_loss: 0.5382 - val_accuracy: 0.8065
Epoch 63/100
8/8 [==============================] - 1s 135ms/step - loss: 0.0166 - accuracy: 0.9959 - val_loss: 0.5027 - val_accuracy: 0.8387
Epoch 64/100
8/8 [==============================] - 1s 140ms/step - loss: 0.0239 - accuracy: 0.9878 - val_loss: 0.4835 - val_accuracy: 0.8548
Epoch 65/100
8/8 [==============================] - 1s 141ms/step - loss: 0.0501 - accuracy: 0.9878 - val_loss: 0.4564 - val_accuracy: 0.8387
Epoch 66/100
8/8 [==============================] - 1s 162ms/step - loss: 0.0212 - accuracy: 0.9878 - val_loss: 0.4216 - val_accuracy: 0.8387
Epoch 67/100
8/8 [==============================] - 1s 169ms/step - loss: 0.0235 - accuracy: 0.9918 - val_loss: 0.3966 - val_accuracy: 0.8548
Epoch 68/100
8/8 [==============================] - 1s 155ms/step - loss: 0.0395 - accuracy: 0.9918 - val_loss: 0.4118 - val_accuracy: 0.8226
Epoch 69/100
8/8 [==============================] - 1s 173ms/step - loss: 0.0442 - accuracy: 0.9837 - val_loss: 0.4085 - val_accuracy: 0.8548
Epoch 70/100
8/8 [==============================] - 1s 164ms/step - loss: 0.0220 - accuracy: 0.9918 - val_loss: 0.4225 - val_accuracy: 0.8548
Epoch 71/100
8/8 [==============================] - 1s 162ms/step - loss: 0.0195 - accuracy: 0.9959 - val_loss: 0.3731 - val_accuracy: 0.8710
Epoch 72/100
8/8 [==============================] - 1s 172ms/step - loss: 0.0095 - accuracy: 0.9959 - val_loss: 0.3665 - val_accuracy: 0.8387
Epoch 73/100
8/8 [==============================] - 1s 144ms/step - loss: 0.0125 - accuracy: 1.0000 - val_loss: 0.3816 - val_accuracy: 0.8548
Epoch 74/100
8/8 [==============================] - 1s 160ms/step - loss: 0.0411 - accuracy: 0.9837 - val_loss: 0.4743 - val_accuracy: 0.8387
Epoch 75/100
8/8 [==============================] - 1s 155ms/step - loss: 0.0286 - accuracy: 0.9918 - val_loss: 0.5085 - val_accuracy: 0.8387
Epoch 76/100
8/8 [==============================] - 1s 147ms/step - loss: 0.0501 - accuracy: 0.9837 - val_loss: 0.5491 - val_accuracy: 0.8548
Epoch 77/100
8/8 [==============================] - 1s 153ms/step - loss: 0.0445 - accuracy: 0.9878 - val_loss: 0.5381 - val_accuracy: 0.8548
Epoch 78/100
8/8 [==============================] - 1s 180ms/step - loss: 0.0085 - accuracy: 0.9959 - val_loss: 0.5165 - val_accuracy: 0.8548
Epoch 79/100
8/8 [==============================] - 1s 152ms/step - loss: 0.0455 - accuracy: 0.9796 - val_loss: 0.4389 - val_accuracy: 0.8871
Epoch 80/100
8/8 [==============================] - 1s 177ms/step - loss: 0.0348 - accuracy: 0.9837 - val_loss: 0.4085 - val_accuracy: 0.8871
Epoch 81/100
8/8 [==============================] - 1s 154ms/step - loss: 0.0270 - accuracy: 0.9918 - val_loss: 0.3844 - val_accuracy: 0.8871
Epoch 82/100
8/8 [==============================] - 1s 173ms/step - loss: 0.0543 - accuracy: 0.9837 - val_loss: 0.3827 - val_accuracy: 0.9194
Epoch 83/100
8/8 [==============================] - 1s 187ms/step - loss: 0.0617 - accuracy: 0.9796 - val_loss: 0.4587 - val_accuracy: 0.8871
Epoch 84/100
8/8 [==============================] - 1s 176ms/step - loss: 0.0086 - accuracy: 1.0000 - val_loss: 0.5762 - val_accuracy: 0.8710
Epoch 85/100
8/8 [==============================] - 2s 210ms/step - loss: 0.0156 - accuracy: 0.9959 - val_loss: 0.5980 - val_accuracy: 0.8710
Epoch 86/100
8/8 [==============================] - 2s 196ms/step - loss: 0.0386 - accuracy: 0.9878 - val_loss: 0.4881 - val_accuracy: 0.8871
Epoch 87/100
8/8 [==============================] - 1s 184ms/step - loss: 0.0195 - accuracy: 0.9959 - val_loss: 0.5051 - val_accuracy: 0.9032
Epoch 88/100
8/8 [==============================] - 1s 188ms/step - loss: 0.0117 - accuracy: 0.9959 - val_loss: 0.5978 - val_accuracy: 0.8387
Epoch 89/100
8/8 [==============================] - 2s 201ms/step - loss: 0.0134 - accuracy: 0.9959 - val_loss: 0.7096 - val_accuracy: 0.7903
Epoch 90/100
8/8 [==============================] - 2s 207ms/step - loss: 0.0282 - accuracy: 0.9878 - val_loss: 0.8517 - val_accuracy: 0.7258
Epoch 91/100
8/8 [==============================] - 1s 195ms/step - loss: 0.0242 - accuracy: 0.9918 - val_loss: 1.0758 - val_accuracy: 0.6613
Epoch 92/100
8/8 [==============================] - 1s 190ms/step - loss: 0.0252 - accuracy: 0.9959 - val_loss: 0.9523 - val_accuracy: 0.6774
Epoch 93/100
8/8 [==============================] - 1s 201ms/step - loss: 0.0137 - accuracy: 0.9959 - val_loss: 0.7976 - val_accuracy: 0.7097
Epoch 94/100
8/8 [==============================] - 1s 199ms/step - loss: 0.0155 - accuracy: 0.9959 - val_loss: 0.7096 - val_accuracy: 0.7742
Epoch 95/100
8/8 [==============================] - 1s 176ms/step - loss: 0.0236 - accuracy: 0.9918 - val_loss: 0.5864 - val_accuracy: 0.8387
Epoch 96/100
8/8 [==============================] - 1s 177ms/step - loss: 0.0190 - accuracy: 0.9918 - val_loss: 0.5050 - val_accuracy: 0.8548
Epoch 97/100
8/8 [==============================] - 1s 174ms/step - loss: 0.0128 - accuracy: 0.9959 - val_loss: 0.4799 - val_accuracy: 0.8710
Epoch 98/100
8/8 [==============================] - 1s 170ms/step - loss: 0.0228 - accuracy: 0.9959 - val_loss: 0.4938 - val_accuracy: 0.8548
Epoch 99/100
8/8 [==============================] - 2s 208ms/step - loss: 0.0071 - accuracy: 1.0000 - val_loss: 0.5290 - val_accuracy: 0.8226
Epoch 100/100
8/8 [==============================] - 2s 200ms/step - loss: 0.0247 - accuracy: 0.9878 - val_loss: 0.5363 - val_accuracy: 0.8065
CPU times: user 3min 57s, sys: 2min 12s, total: 6min 9s
Wall time: 1min 59s

Performance during model fitting

Let's take a look at the loss and accuracy values during the different epochs, starting with accuracy values:

In [46]:
fig = plt.figure(figsize=(10, 4))
epoch = np.arange(nEpochs) + 1
fontsize = 16
plt.plot(epoch, fit.history['accuracy'], marker="o", linewidth=2,
         color="steelblue", label="acc")
plt.plot(epoch, fit.history['val_accuracy'], marker="o", linewidth=2,
         color="orange", label="val_acc")
plt.xlabel('epoch', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(frameon=False, fontsize=16);

Given that we are running this interactively in a Jupyter notebook, we can make use of its capabilities and create an interactive graph using plotly:

In [47]:
def accuracy_epoch_plotly():

    fig = go.Figure()

    fig.add_trace(go.Scatter(x=epoch, y=fit.history['accuracy'],
                        mode='lines+markers',
                        name='acc'))
    fig.add_trace(go.Scatter(x=epoch, y=fit.history['val_accuracy'],
                        mode='lines+markers',
                        name='val_acc'))

    fig.update_layout(
       title={
            'text': "Accuracy per epoch",
            'y':0.95,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        xaxis_title="Epoch",
        yaxis_title="Accuracy",
        legend_title="Type",
        template='plotly_white'
    )

    fig.show()
In [48]:
accuracy_epoch_plotly()

Next, we check the loss values, first via a static plot:

In [49]:
fig = plt.figure(figsize=(10, 4))
epoch = np.arange(nEpochs) + 1
fontsize = 16
plt.plot(epoch, fit.history['loss'], marker="o", linewidth=2,
         color="steelblue", label="loss")
plt.plot(epoch, fit.history['val_loss'], marker="o", linewidth=2,
         color="orange", label="val_loss")
plt.xlabel('epoch', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(frameon=False, fontsize=16);

and second via an interactive plot:

In [50]:
def loss_epoch_plotly():

    fig = go.Figure()

    fig.add_trace(go.Scatter(x=epoch, y=fit.history['loss'],
                        mode='lines+markers',
                        name='loss'))
    fig.add_trace(go.Scatter(x=epoch, y=fit.history['val_loss'],
                        mode='lines+markers',
                        name='val_loss'))

    fig.update_layout(
       title={
            'text': "Loss per epoch",
            'y':0.95,
            'x':0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        xaxis_title="Epoch",
        yaxis_title="Loss",
        legend_title="Type",
        template='plotly_white'
    )

    fig.show()
In [51]:
loss_epoch_plotly()

Great, the training accuracy keeps climbing and the training loss continues to drop, even though the validation metrics fluctuate quite a bit across epochs. But how well is our model doing on the test data?

Evaluating the model

In [52]:
evaluation = model.evaluate(X_zscore_te, y_test)
print('Loss in Test set:      %.02f' % (evaluation[0]))
print('Accuracy in Test set:  %.02f' % (evaluation[1] * 100))
3/3 [==============================] - 0s 14ms/step - loss: 0.6017 - accuracy: 0.8182
Loss in Test set:      0.60
Accuracy in Test set:  81.82
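
To put this number into perspective, we can compare it to a naive baseline that always predicts the majority class of the test set:

# Accuracy of always predicting the test set's majority class
baseline = max(y_test.mean(), 1 - y_test.mean())
print('Baseline accuracy:     %.02f' % (baseline * 100))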

Confusion Matrix

We can also evaluate the model in more detail by computing a confusion matrix, which provides us with more information about the sensitivity and specificity of our model (see also the short computation after the heatmap below). First, let's get the predicted and true labels:

In [53]:
y_pred = np.argmax(model.predict(X_zscore_te), axis=1)
y_pred
Out[53]:
array([1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0,
       0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1,
       0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1])
In [54]:
y_true = y_test * 1
y_true
Out[54]:
array([1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0,
       1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1])

We can compute the confusion matrix and plot it:

In [55]:
from sklearn.metrics import confusion_matrix
import pandas as pd
In [56]:
class_labels = ['closed', 'open']
cm = pd.DataFrame(confusion_matrix(y_true, y_pred), index=class_labels, columns=class_labels)
sns.heatmap(cm, square=True, annot=True);
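
From this matrix we can derive sensitivity and specificity by hand. As a minimal sketch (treating 'open' as the positive class, which follows from how we defined y_true above):

# Unpack the 2x2 confusion matrix (rows: true class, columns: predicted class)
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
sensitivity = tp / (tp + fn)  # fraction of 'open' volumes correctly detected
specificity = tn / (tn + fp)  # fraction of 'closed' volumes correctly detected
print('Sensitivity: %.02f' % sensitivity)
print('Specificity: %.02f' % specificity)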

Again, an interactive plot might be nice as well:

In [57]:
def confusion_matrix_plotly():
    
    z_text = [[str(y) for y in x] for x in cm.to_numpy()]
    fig = ff.create_annotated_heatmap(cm.to_numpy(), x=class_labels, y=class_labels, annotation_text=z_text, colorscale='Magma')

    # add custom xaxis title
    fig.add_annotation(dict(font=dict(color="black",size=14),
                            x=0.5,
                            y=-0.15,
                            showarrow=False,
                            text="Predicted value",
                            xref="paper",
                            yref="paper"))

    # add custom yaxis title
    fig.add_annotation(dict(font=dict(color="black",size=14),
                            x=-0.1,
                            y=0.45,
                            showarrow=False,
                            text="Real value",
                            textangle=-90,
                            xref="paper",
                            yref="paper"))
    fig.show()
In [58]:
confusion_matrix_plotly()

Analyze prediction values

What are the predicted values of the test set?

In [59]:
y_pred = model.predict(X_zscore_te)
y_pred[:10,:]
Out[59]:
array([[6.3193397e-04, 9.9936813e-01],
       [7.3998719e-01, 2.6001284e-01],
       [2.6101706e-04, 9.9973899e-01],
       [9.9992204e-01, 7.7959899e-05],
       [4.1116646e-04, 9.9958879e-01],
       [9.9938703e-01, 6.1293563e-04],
       [9.6927714e-01, 3.0722868e-02],
       [2.5064608e-03, 9.9749351e-01],
       [9.6933562e-01, 3.0664392e-02],
       [2.3600699e-01, 7.6399302e-01]], dtype=float32)

As you can see, those values can be between 0 and 1.
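
Because the final layer is a softmax, the two values in each row sum to 1, so every row can be read as a probability distribution over the two classes:

# Sanity check: the softmax outputs of each volume sum to 1
print(np.allclose(y_pred.sum(axis=1), 1.0))  # True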

In [60]:
fig = plt.figure(figsize=(6, 4))
fontsize = 16
plt.hist(y_pred[:,0], bins=16, label='eyes closed')
plt.hist(y_pred[:,1], bins=16, label='eyes open');
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(frameon=False, fontsize=16);

As usual, we also generate an interactive plot:

In [61]:
fig = go.Figure()
fig.add_trace(go.Histogram(x=y_pred[:,0],name='eyes closed', nbinsx=16, marker_color='blue'))
fig.add_trace(go.Histogram(x=y_pred[:,1],name='eyes open', nbinsx=16, marker_color='orange'))

fig.update_layout(barmode='stack', template='plotly_white')

fig.show()

The more both distributions cluster around chance level (0.5), the weaker your model is.
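
One simple summary number is the average probability assigned to the winning class; values close to 0.5 would mean the model is hardly more confident than chance:

# Average confidence in the predicted (winning) class
print(y_pred.max(axis=1).mean())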

Note: Keep in mind that we trained the whole model on only one split of training and test data. Ideally, you would repeat this process many times so that your results become less dependent on the particular split (a sketch of such a scheme follows below).
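
As a minimal sketch of such a repetition scheme (not run here, as every repetition requires re-training the network from scratch), one could use scikit-learn's StratifiedShuffleSplit:

# Repeated stratified 80/20 splits; each iteration needs a freshly built model
from sklearn.model_selection import StratifiedShuffleSplit

sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
for train_idx, test_idx in sss.split(data, labels):
    # build, z-score, fit and evaluate a new model on this split
    pass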

Visualizing Hidden Layers

Finally, as a cool additional feature: we can visualize the activations of the individual filters in the hidden layers. So let's get to it:

In [62]:
# Aggregate the layers
layer_dict = dict([(layer.name, layer) for layer in model.layers])
In [85]:
from tensorflow.keras import backend as K

# Specify a function that visualizes the activations of a layer
def show_activation(layer_name):
    
    layer_output = layer_dict[layer_name].output

    # Build a function mapping the model input to the layer's output
    fn = K.function([model.input], [layer_output])
    
    # Use the first training volume as example input
    inp = X_train[0:1]
    
    this_hidden = fn([inp])[0]
    
    # plot the activations, 8 filters per row
    plt.figure(figsize=(16,8))
    nFilters = this_hidden.shape[-1]
    nColumn = 8 if nFilters >= 8 else nFilters
    for i in range(nFilters):
        plt.subplot(int(nFilters / int(nColumn)), int(nColumn), i+1)
        plt.imshow(this_hidden[0,:,:,i], cmap='magma', interpolation='nearest')
        plt.axis('off')
    
    return 

Now we can plot the activations of the hidden layers:

In [86]:
layer_dict
Out[86]:
{'conv2d': <keras.layers.convolutional.Conv2D at 0x182355360>,
 'batch_normalization': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x182364d90>,
 'max_pooling2d': <keras.layers.pooling.MaxPooling2D at 0x182356260>,
 'conv2d_1': <keras.layers.convolutional.Conv2D at 0x1823576a0>,
 'batch_normalization_1': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x182355db0>,
 'max_pooling2d_1': <keras.layers.pooling.MaxPooling2D at 0x182355750>,
 'conv2d_2': <keras.layers.convolutional.Conv2D at 0x1824b9de0>,
 'batch_normalization_2': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824b9c60>,
 'max_pooling2d_2': <keras.layers.pooling.MaxPooling2D at 0x1824bb160>,
 'flatten': <keras.layers.core.flatten.Flatten at 0x1824bb0a0>,
 'dropout': <keras.layers.core.dropout.Dropout at 0x1824bb910>,
 'dense': <keras.layers.core.dense.Dense at 0x1824d9150>,
 'batch_normalization_3': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824b94b0>,
 'dropout_1': <keras.layers.core.dropout.Dropout at 0x1824d9930>,
 'dense_1': <keras.layers.core.dense.Dense at 0x1824d9510>,
 'batch_normalization_4': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824db130>,
 'dropout_2': <keras.layers.core.dropout.Dropout at 0x1824f0cd0>,
 'dense_2': <keras.layers.core.dense.Dense at 0x1824db340>,
 'batch_normalization_5': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824f1060>,
 'dropout_3': <keras.layers.core.dropout.Dropout at 0x1824f2ef0>,
 'dense_3': <keras.layers.core.dense.Dense at 0x1824f30d0>}
In [87]:
show_activation('conv2d_1')
In [88]:
show_activation('conv2d_2')

Conclusion of 2D example

The classification accuracy on the training set gets incredibly high, while the validation set also reaches a reasonable accuracy level above 80%. Nonetheless, by investigating only a slab of our fMRI dataset, we might have missed important information elsewhere in the brain.

An alternative solution might be to use 3D convolutional neural networks. But keep in mind that they will have even more parameters and will probably take much longer to fit to the training data. Having said that, let's get to it.
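
As a rough idea of what such a model could look like, here is a minimal, untrained sketch of a first 3D convolution block. Note that the input_shape is hypothetical and depends on how you reshape the volumes to include a channel dimension:

# Hypothetical first block of a 3D CNN (illustrative sketch only)
from tensorflow.keras.layers import Conv3D, MaxPooling3D

model_3d = Sequential()
model_3d.add(Conv3D(32, (3, 3, 3), activation='relu',
                    input_shape=(40, 56, 34, 1)))  # x, y, z, channel (assumed)
model_3d.add(MaxPooling3D())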

Super secret fancy surprise

Going back to graphics, outputs and the interactive features of Jupyter notebooks, we can go even crazier and include a running TensorBoard instance to enable interactive evaluation of our model (the future is now):

In [92]:
%load_ext tensorboard
%tensorboard --logdir logs
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard